import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import os
import sys
import pickle
import yaml
import statsmodels.api as sm 
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from pprint import pprint
from statistics import mean
import random
#import pdb
random.seed(50)
import math
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import cm
import matplotlib.font_manager as fm
import matplotlib
import sklearn.metrics as metrics
import pdb
from scipy import stats
#sys.path.insert(1,'/REDACTED/fairness/code/utilities')
#sys.path.insert(1,'/REDACTED/fairness/code/config/')
#from configureColors import *
#sys.path.insert(1,'/REDACTED/fairness/code/utilities') 

#defaultout = '/REDACTED/output/outputForPaper/Trajectories_temp/'
# Default output directory used by the plotting/table helpers in this module.
defaultout = '/REDACTED/audit_descr_stats/data/'

# ---- Plot defaults ---------------------------------------------------------
# Project-wide matplotlib style sheet.
plt.style.use('/REDACTED/fairness/code/config/fairness.mplstyle')

# Register the font file cmunrm.ttf under the family name 'latex' at the front
# of matplotlib's font list, then make it the default font family.
fe = fm.FontEntry(fname='cmunrm.ttf', name='latex')
fm.fontManager.ttflist.insert(0, fe)
matplotlib.rcParams['font.family'] = fe.name

# Stanford palette: stanford_palette[i] is the hex code for the name color_codes[i].
stanford_palette = [
    '#f0f4f5', '#d0d8da', '#aabeC6', '#009abb', '#007c92', '#09425a',
    '#c7d1c6', '#80982a', '#556222', '#b96d12', '#53284f', '#5e3032',
    '#8c1515', '#dddddd', '#cccccc', '#999999', '#666666', '#333333',
    '#000000',
]
color_codes = [
    'background blue', 'gray blue', 'dark gray-blue', 'accent blue', 'link blue', 'dark blue',
    'light green', 'bright green', 'dark green', 'orange', 'purple', 'maroon',
    'cardinal red', 'line gray', 'light gray', 'gray', 'med gray', 'text gray',
    'black',
]
# Register every palette entry as a named matplotlib color, so e.g.
# color='cardinal red' works anywhere a color is accepted.
cdict = {code_name: mcolors.to_rgba(hex_value)
         for code_name, hex_value in zip(color_codes, stanford_palette)}
mcolors.get_named_colors_mapping().update(cdict)

#####################################################
########### Pull in data ############################
#####################################################


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    pickle can execute arbitrary code on load; only use this on trusted,
    project-generated files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


status_quo_point = _load_pickle('/REDACTED/fairness/code/rf/data/status_quo.pickle')
restricted_rf = _load_pickle('/REDACTED/fairness/code/rf/data/res_rf_output_return_recalculated.pickle')
unrestricted_rf = _load_pickle('/REDACTED/fairness/code/rf/data/unres_rf_output.pickle')


######################################################
############# Functions ##############################
######################################################
def get_meas_varb(plot_dict,measure='mean',varb='rev',demillion=True, ac_shares=False, activity_code=None):
    """Collect the first value of every matching entry of ``plot_dict`` as floats.

    Scalar values are treated as one-element sequences, so each matching entry
    contributes its first element. Keys are matched by substring, in dict order.

    Args:
        plot_dict: mapping from key strings to scalars or indexable sequences.
        measure: statistic prefix, e.g. ``'mean'`` or ``'std'``.
        varb: outcome name; keys containing ``'<measure>_<varb>'`` are selected.
        demillion: if True and ``varb`` contains ``'revenue'``, ``'net_rev'``,
            ``'cost'`` or ``'total_ref_adj'``, results are divided by 1,000,000.
        ac_shares: if True, ignore ``measure``/``varb`` and select keys that
            contain ``str(activity_code)`` instead (no rescaling applied).
        activity_code: code to match when ``ac_shares`` is True (callers in
            this file pass 270-281).

    Returns:
        list of floats, one per matching key.
    """
    def _first(value):
        # A scalar stands in for a length-one sequence.
        return value if np.isscalar(value) == True else value[0]

    if ac_shares == True:
        code = str(activity_code)
        return [float(_first(value)) for key, value in plot_dict.items() if code in key]

    target = measure + '_' + varb
    picked = [float(_first(value)) for key, value in plot_dict.items() if target in key]

    dollar_tags = ('revenue', 'net_rev', 'cost', 'total_ref_adj')
    if demillion == True and any(tag in varb for tag in dollar_tags):
        picked = [amount / 1000000 for amount in picked]
    return picked

def makeOutcomeTable(status_quo_point, 
                     data_dict,
                     outdir,
                     constrained = True):
    
    if constrained:
        res = "res"
    else:
        res = "unres"

    table = pd.DataFrame(columns=['Metric', 'Regressor', 'Classifier', 'Oracle'])
    
    table['Metric'] = ['EITC Amount Decreased', 
                       'EITC Reduced to $0', 
                       'Mean EITC Adjustment',
                       'Single Filer',
                       'Married Filing Jointly',
                       'Head of Household',
                       'Filing Status Changed After Audit',
                       'Filing Status Changed from Head of Household',
                       'Count of Exemptions Decreased', 
                       'Count of Exemptions Changed', 
                       'Mean Change in Exemptions',
                       'AGI Increased', 
                       'Mean Change in AGI',
                       'Taxable Income Increased', 
                       'Mean Change in Taxable Income',
                       'EITC Decreased and Taxable Income Increased',
                       'No-Change Rate ($100)',
                       'Male PF',
                       'Mean AGI',
                       'Any SE Income (reported)',
                       'Mean SE Income (reported)'
]
    
    varbs=['eitc_amt_decr', 
           'eitc_amt_to_zero',
           'eitc_amt_dif',
           'filed_single',
           'filed_mfj',
           'filed_hoh',
           'filing_status_chg',
           'changed_from_hoh',
           'exemptn_total_decr', 
           'exemptn_total_chg', 
           'exemptn_total_dif',
           'agi_amt_incr',  
           'agi_amt_dif',
           'txbl_incm_incr', 
           'txbl_incm_amt_dif',
           'eitc_decr_and_txbl_inc_incr',
           'nochg_lt100',
           'male_pf', 
           'agi_amt',
           'any_se_inc', 
           'tot_se_inc' 
]
    
    idx=math.floor(status_quo_point['eitc_budget']*10000)-1
    print('audit budget is: ' + str(status_quo_point['eitc_budget']))
    print('corresponding index is: ' + str(idx))

    reg_list = [np.nan]*len(varbs)
    cls_list = [np.nan]*len(varbs)
    oracle_list = [np.nan]*len(varbs)
    
    for i in range(len(varbs)):
        reg_list[i] = data_dict[res + '_eitc_rf_dep_database'][varbs[i] + '_mean'][idx]
        cls_list[i] = data_dict[res + '_eitc_cls_dep_database'][varbs[i] + '_mean'][idx]
        oracle_list[i] = data_dict[res + '_eitc_oracle'][varbs[i] + '_mean'][idx]
    
    table['Regressor'] = reg_list
    table['Classifier'] = cls_list
    table['Oracle'] = oracle_list
    
    table = table.round(decimals=3)
    table.to_csv(outdir + 'line_item_chgs_' + res + '_models.csv')

    with open(outdir + 'line_item_chgs_' + res + '_models.tex', 'w') as tf:
        tf.write(table.to_latex())

    return table

def makeEmptyPlot(x_lab='Detected Underreporting ($ Millions)',
                  y_lab='Disparity (percentage points)',
                  out='/REDACTED/audit_descr_stats/data/final_paper_figs/'):
    """Save an empty 10x8 figure carrying only axis labels to ``<out>empty_fig.png``.

    Args:
        x_lab: x-axis label (the previous default was garbled by a find/replace
            into 'Underreportaxpayer_idg').
        y_lab: y-axis label.
        out: output prefix; the file name is appended directly.

    Returns:
        None.
    """
    fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(10,8))
    fig.suptitle('', fontsize=24)
    ax.set_xlabel(x_lab, fontsize=20)
    ax.set_ylabel(y_lab, fontsize=20)

    fig.savefig(out + 'empty_fig.png')

    return None

def makeTrajectoryPlot(modelnames,
                        quantities,
                        status_quo_point,
                        colors,
                        linestyles,
                        trajname,
                        plot_sq=True,
                        dict_x_varb='bootstrap_revenue',
                        dict_y_varb='full_chen_fair',
                        dict_se_varb='full_chen_fair',
                        sq_y_varb='eitc_chen_fair',
                        x_lab='Detected Underreporting ($ Millions)',
                        y_lab='Disparity (percentage points)',
                        activity_code=None,
                        out=defaultout,
                        eitc=False,
                        offset=False,
                        annotatePct=True,
                        annotateModelName=False,
                        annotationcoords=[[0,0]],
                        ref_pt_offset=[False],
                        sq_coords_eitc=[500,-0.001],
                        sq_coords_pop=[1000,-0.0005],
                        plot_error = True,
                        title_lab = ''):
    """Plot x/y trajectories for several models and save to ``<out><trajname>.png``.

    For model ``i`` the x series is ``quantities[i][dict_x_varb + '_mean']`` and
    the y series is ``quantities[i][dict_y_varb + '_mean']``. If ``plot_error``
    is set, a band of +/- 1.96 * ``quantities[i][dict_se_varb + '_std']`` is
    shaded; this is a ~95% normal band if that series is a standard error.
    Points at grid indices 99/199/299 are marked and labelled 1%/2%/3%.

    Status-quo values are read from ``status_quo_point`` using the key prefix
    ``str(activity_code)`` when given, otherwise ``'eitc'`` (``eitc=True``) or
    ``'overall'``.

    Args:
        modelnames, colors, linestyles: per-model legend label / color / style.
        quantities: per-model dicts of series (as built by getAllQuantities).
        plot_sq: draw the status-quo crosshair and annotation.
        annotatePct: mark each trajectory at the status-quo audit budget.
        ref_pt_offset: per-model flags; True places that model's budget label
            below the point. Missing entries count as False.
        annotationcoords: per-model (x, y) for model-name labels when
            ``annotateModelName`` is True.
        offset: stagger the 1%/2%/3% labels of odd-numbered models downward.
        sq_coords_eitc / sq_coords_pop: (dx, dy) text offset of the status-quo
            label for the EITC / overall or activity-code cases.

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(10,8))
    fig.suptitle('', fontsize=24)
    ax.set_xlabel(x_lab, fontsize=20)
    ax.set_ylabel(y_lab, fontsize=20)
    ax.set_title(str(title_lab), fontsize=20)
    ax.tick_params(labelsize=16)
    ax.axhline(y=0, color='gray')

    xs = [q[dict_x_varb + '_mean'] for q in quantities]
    ys = [q[dict_y_varb + '_mean'] for q in quantities]
    errs = [q[dict_se_varb + '_std'] for q in quantities]

    # Key prefix into status_quo_point; an activity code overrides the eitc flag.
    use_eitc = activity_code is None and eitc
    if activity_code is not None:
        sq_prefix = str(activity_code)
    elif use_eitc:
        sq_prefix = 'eitc'
    else:
        sq_prefix = 'overall'

    if plot_sq:
        sq_x = status_quo_point[sq_prefix + '_rev'] / 1000000
        sq_y = status_quo_point[sq_prefix + '_' + sq_y_varb]
        ax.axhline(y=sq_y, color='red', linestyle='dotted')
        ax.axvline(x=sq_x, color='red', linestyle='dotted')
        ax.plot(sq_x, sq_y, 'kx', markersize=10)

    ## scale up y axis to percentages (instead of decimals)
    scale_y = 0.01
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda val, pos: '{0:g}'.format(val / scale_y)))

    if plot_sq:
        sq_offset = sq_coords_eitc if use_eitc else sq_coords_pop
        sq_rate = round(status_quo_point[sq_prefix + '_budget'] * 100, 2)
        ax.annotate('Status quo, ' + str(sq_rate) + '% audit rate', xy=(sq_x, sq_y),
                    xytext=(sq_x + sq_offset[0], sq_y + sq_offset[1]), size=16)

    if annotatePct:
        ref_budget = status_quo_point[sq_prefix + '_budget']
        if activity_code is not None:
            # Truncate to the grid point at or below the budget; rounding to 2
            # decimals first absorbs float error in the scaling.
            ref_idx = math.floor(round(ref_budget * 10000 - 1, 2))
            ref_dx, ref_dy_shifted, ref_dy_default = -200, -0.001, 0.00075
        else:
            # Snap to the nearest grid point. Rounding after scaling avoids
            # products like 28.999999999999996 flooring one step too low.
            ref_idx = round(round(ref_budget, 4) * 10000) - 1
            if use_eitc:
                ref_dx, ref_dy_shifted, ref_dy_default = 75, -0.0015, 0.00025
            else:
                ref_dx, ref_dy_shifted, ref_dy_default = 100, -0.001, 0.00075
        ref_label = str(round(ref_budget * 100, 2)) + '%'

    for i, name in enumerate(modelnames):
        x = np.array(xs[i])
        y = np.array(ys[i])
        ax.plot(x, y, label=name, color=colors[i], linestyle=linestyles[i])
        if plot_error:
            # np.float was removed in NumPy 1.24; the builtin float is equivalent.
            yerr = np.array(errs[i]).astype(float)
            ax.fill_between(x, y-(yerr*1.96), y+(yerr*1.96), alpha=0.2, edgecolor='k', facecolor=colors[i])

        # Mark the 1%, 2% and 3% audit-rate points (grid indices 99, 199, 299).
        # With `offset`, odd-numbered models' labels are nudged down to reduce overlap.
        dist = (float(i) % 2) * 0.001 if offset else 0.0
        for j in (99, 199, 299):
            ax.annotate(str((j + 1) // 100) + '%', xy=(x[j]+50, y[j]-dist), fontsize=16)
            ax.plot(x[j], y[j], 'ko', markersize=5)

        if annotatePct:
            shifted = i < len(ref_pt_offset) and bool(ref_pt_offset[i])
            dy = ref_dy_shifted if shifted else ref_dy_default
            ax.plot(x[ref_idx], y[ref_idx], 'ko', markersize=10)
            ax.annotate(ref_label, xy=(x[ref_idx], y[ref_idx]),
                        xytext=(x[ref_idx] + ref_dx, y[ref_idx] + dy), size=16)

        if annotateModelName:
            ax.annotate(name, (annotationcoords[i][0], annotationcoords[i][1]), size=20)

    fig.savefig(out + trajname + '.png')
    return fig
    
def getOneSidedPValue(quantities, status_quo_point, mean_varb='bootstrap_chen_fair', std_varb='bootstrap_chen_fair', bootstrap_iters=100):

    ''' 
    Compare the first two models at the status-quo EITC budget with a t-test.

    quantities: plot dictionaries; only quantities[0] and quantities[1] are used.
    status_quo_point: dictionary with the status quo audit rate under 'eitc_budget'
        (grid index = budget * 10000 - 1).
    mean_varb: variable name for the mean value ('<mean_varb>_mean' series).
    std_varb: variable name for the standard error ('<std_varb>_std' series).
        It is multiplied by sqrt(bootstrap_iters) to get a per-draw standard
        deviation, i.e. it is treated as a standard error over bootstrap_iters draws.
    bootstrap_iters: sample size used for both groups.

    Returns half of the two-sided p-value from scipy's ttest_ind_from_stats
    (default equal_var=True, pooled variance). This equals the one-sided p-value
    in the direction of the observed difference, not for a fixed alternative.
    '''

    budget = status_quo_point['eitc_budget']
    # Largest grid index <= budget; the epsilon absorbs float error such as
    # budget * 10000 == 28.999999999999996, which would otherwise floor one step low.
    idx = math.floor(budget * 10000 + 1e-9) - 1

    first, second = quantities[0], quantities[1]
    scale = math.sqrt(bootstrap_iters)
    result = stats.ttest_ind_from_stats(mean1=first[mean_varb + '_mean'][idx],
                                        std1=first[std_varb + '_std'][idx] * scale,
                                        nobs1=bootstrap_iters,
                                        mean2=second[mean_varb + '_mean'][idx],
                                        std2=second[std_varb + '_std'][idx] * scale,
                                        nobs2=bootstrap_iters)

    return result[1] / 2

def makeDotPlot(modelnames,
                quantities,
                status_quo_point,
                colors,
                plotname,
                plot_sq=True,
                activity_code=None,
                x_varb='bootstrap_revenue',
                x_varb_sq='rev',
                x_lab='Detected Underreporting ($ Millions)',
                y_varb='bootstrap_fair',
                y_varb_sq='chen_fair',
                se_varb='bootstrap_fair',
                y_lab='Black/Non-Black Disparity (percentage points)',
                demillion_x=True,
                demillion_y=False,
                scale_x=False,
                plot_ci = True,
                dotplot_symbol = None,
                budget='auto',
                out=defaultout,
                marker_size=10,
                eitc=True,
                offsets=[[0,0]],
                annotateModelName=True,
                annotationcoords=[[0,0]],
                sq_coords_eitc=[500,-0.001],
                sq_coords_pop=[1000,-0.0005]):
    """Dot-plot version of makeTrajectoryPlot for comparing many models at one budget.

    Each model contributes one point at the status-quo budget's grid index
    (``budget * 10000 - 1``), optionally with +/- 1.96 * std error bars.
    The figure is saved to ``<out><plotname>.png``.

    Args:
        budget: must be ``'auto'`` unless ``activity_code`` is given; the index
            is taken from ``status_quo_point`` ('eitc_budget', 'overall_budget'
            or '<activity_code>_budget').
        eitc: selects 'eitc_' vs 'overall_' keys for the status-quo point.
        demillion_x / demillion_y: divide the status-quo x / y by 1,000,000.
        scale_x: when True and ``demillion_x`` is False, show x ticks x100.
        dotplot_symbol: optional per-model marker list (default 'o').
        annotationcoords: per-model label coordinates; models without an entry
            are labelled at their own dot.
        offsets: accepted for compatibility; not used.

    Returns:
        The matplotlib Figure.

    Raises:
        Exception: if ``budget`` is not 'auto' and no activity code is given.
    """
    fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(10,8))
    fig.suptitle('', fontsize=24)

    ## define coords for plot
    if activity_code is not None:
        sq_budget = status_quo_point[str(activity_code) + '_budget']
    elif budget == 'auto':
        sq_budget = status_quo_point['eitc_budget' if eitc else 'overall_budget']
    else:
        raise Exception('Benchmark budgets now auto-fill in code. Set budget argument to \'auto\'.')
    # Largest grid index <= budget; the epsilon absorbs float error in budget * 10000.
    idx = math.floor(sq_budget * 10000 + 1e-9) - 1

    xs = [q[x_varb + '_mean'][idx] for q in quantities]
    x_errs = [q[x_varb + '_std'][idx] for q in quantities]
    ys = [q[y_varb + '_mean'][idx] for q in quantities]
    y_errs = [q[se_varb + '_std'][idx] for q in quantities]

    ax.set_xlabel(x_lab, fontsize=20)
    ax.set_ylabel(y_lab, fontsize=20)
    ax.set_title('', fontsize=20)
    ax.tick_params(labelsize=16)
    ax.axhline(y=0, color='black')

    # Status-quo reference point (red crosshair), EITC-only or overall.
    sq_prefix = ('eitc_' if eitc else 'overall_') if plot_sq else None
    if sq_prefix is not None:
        sq_x = status_quo_point[sq_prefix + x_varb_sq]
        if demillion_x:
            sq_x = sq_x / 1000000
        sq_y = status_quo_point[sq_prefix + y_varb_sq]
        if demillion_y:
            sq_y = sq_y / 1000000
        ax.axhline(y=sq_y, color='red', linestyle='dotted')
        ax.axvline(x=sq_x, color='red', linestyle='dotted')
        ax.plot(sq_x, sq_y, 'kx', markersize=10)

    ## scale up y axis to percentages (instead of decimals)
    scale_y = 0.01
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda val, pos: '{0:g}'.format(val / scale_y)))

    ## scale x axis too if needed (previously referenced an undefined `demillion`)
    if scale_x and not demillion_x:
        scale_x_factor = 0.01
        ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda val, pos: '{0:g}'.format(val / scale_x_factor)))

    if sq_prefix is not None:
        sq_offset = sq_coords_eitc if eitc else sq_coords_pop
        sq_rate = round(status_quo_point[sq_prefix + 'budget'] * 100, 2)
        ax.annotate('Status quo, ' + str(sq_rate) + '% audit rate', xy=(sq_x, sq_y),
                    xytext=(sq_x + sq_offset[0], sq_y + sq_offset[1]), size=16)

    for i, name in enumerate(modelnames):
        x = np.array(xs[i])
        y = np.array(ys[i])

        if dotplot_symbol is None:
            ax.plot(x, y, 'o', label=name, color=colors[i], markersize=marker_size)
        else:
            ax.plot(x, y, marker=dotplot_symbol[i], label=name, color=colors[i], markersize=marker_size)

        if plot_ci:
            # 1.96 * std: ~95% normal interval if the std series is a standard error.
            ax.errorbar(x, y, xerr=np.array(x_errs[i])*1.96, yerr=np.array(y_errs[i])*1.96,
                        ecolor=colors[i], capsize=(marker_size+2))

        if annotateModelName:
            # The default annotationcoords has one entry; fall back to the dot itself.
            if i < len(annotationcoords):
                coords = (annotationcoords[i][0], annotationcoords[i][1])
            else:
                coords = (float(xs[i]), float(ys[i]))
            ax.annotate(name, coords, size=16)

    fig.savefig(out + plotname + '.png')

    return fig
    
def twoGroupBarPlotOld(share_x=False, 
             share_y=True, 
             fig_size=(14,8), 
             sup_title='Activity Code Shares for Oracle and Random Forest',
             sup_title_fontsize=24,
             y_label='Share of Audits',
             plot_titles=['Status Quo', 'Oracle', 'Random Forest'],
             fig_fontsize=20,
             label_size=16,
             alphas=[1,0.3],
             group_colors=['accent blue', 'accent blue'],
             group_labels=['270', '271'],
             right_label_only=True,
             label_loc='upper right',
             sq_data_bottom=[],
             sq_data_top=[],
             oracle_data_bottom=[],
             oracle_data_top=[],
             rf_data_bottom=[],
             rf_data_top=[],
             output_filename='test.png',
             out=defaultout):
    """Draw two-group stacked bars for status quo, oracle and random forest.

    Fixed three-panel variant of twoGroupBarPlot. Each panel stacks its
    ``*_data_top`` bar on top of its ``*_data_bottom`` bar. The figure is saved
    to ``<out><output_filename>`` and then shown.

    Returns:
        The matplotlib Figure.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=share_x, sharey=share_y, figsize=fig_size, gridspec_kw={'width_ratios': [1, 1, 1]})
    fig.suptitle(sup_title, fontsize=sup_title_fontsize)
    ax1.set_ylabel(y_label, fontsize=fig_fontsize)

    panels = [(ax1, sq_data_bottom, sq_data_top),
              (ax2, oracle_data_bottom, oracle_data_top),
              (ax3, rf_data_bottom, rf_data_top)]
    for i, (ax, bottom_data, top_data) in enumerate(panels):
        ax.set_title(plot_titles[i], fontsize=fig_fontsize)
        ax.tick_params(labelsize=label_size)
        if i > 0:
            # With share_y, matplotlib hides inner y tick labels; show them again.
            ax.tick_params(which='both', labelleft=True)
        ax.bar(" ", bottom_data, label=group_labels[0], color=group_colors[0], alpha=alphas[0])
        ax.bar(" ", top_data, bottom=bottom_data, label=group_labels[1], color=group_colors[1], alpha=alphas[1])

    # Legend entries come from the last panel, reversed to match stacking order.
    handles, leg_labels = ax3.get_legend_handles_labels()
    legend_axes = [ax3] if right_label_only == True else [ax1, ax2, ax3]
    for ax in legend_axes:
        ax.legend(handles[::-1], leg_labels[::-1], fontsize=label_size, loc=label_loc)

    for ax in (ax1, ax2, ax3):
        ax.tick_params(bottom=False)

    # Save before show(): interactive/inline backends may close the figure in
    # show(), after which plt.savefig would write a blank image.
    fig.savefig(out + output_filename)
    plt.show()

    print('output saved successfully as ' + output_filename)

    return fig
    

def getAllQuantities(plot_dict, bootstrap=False,
                     config_path='/REDACTED/horizontal/fairness/code/rf/config/data-config.yaml'):
    """Extract mean/std series for every configured outcome plus activity-code shares.

    Args:
        plot_dict: raw model-output mapping passed to get_meas_varb.
        bootstrap: also include the config's 'bootstrap_outcomes' list.
        config_path: YAML file listing 'full_data_outcomes', '5_fold_outcomes'
            and (when ``bootstrap``) 'bootstrap_outcomes'.

    Returns:
        dict mapping '<varb>_mean' / '<varb>_std' to lists from get_meas_varb,
        and str(code) for codes 270-281 to that code's share list.
    """
    quants = {}
    # Context manager so the config file handle is always closed.
    with open(config_path, 'r') as stream:
        config = yaml.safe_load(stream)
    print('config file loaded')
    if bootstrap == True:
        varbs = config['full_data_outcomes'] + config['bootstrap_outcomes'] + config['5_fold_outcomes']
    else:
        varbs = config['full_data_outcomes'] + config['5_fold_outcomes']

    for varb in varbs:
        quants[varb + '_mean'] = get_meas_varb(plot_dict, 'mean', varb)
        quants[varb + '_std'] = get_meas_varb(plot_dict, 'std', varb)
    for ac in range(270, 282):
        quants[str(ac)] = get_meas_varb(plot_dict, ac_shares=True, activity_code=ac)
    return quants


def twoGroupBarPlot(share_x=False, 
             share_y=True, 
             fig_size=(14,8), 
             sup_title='Activity Code Shares for Oracle and Random Forest',
             sup_title_fontsize=24,
             y_label='Share of Audits',
             plot_titles=['Status Quo', 'Oracle', 'Random Forest'],
             fig_fontsize=20,
             label_size=16,
             alphas=[1,0.3],
             group_colors=['accent blue', 'accent blue'],
             group_labels=['270', '271'],
             right_label_only=True,
             label_loc='upper right',
             sq_data_bottom=[],
             sq_data_top=[],
             oracle_data_bottom=[],
             oracle_data_top=[],
             rf_data_bottom=[],
             rf_data_top=[],
             output_filename='test.png',
             out=defaultout,
             usageFlags = [True,True,True]):
    """Draw two-group stacked bars for the enabled status quo / oracle / RF panels.

    ``usageFlags`` enables the (status quo, oracle, random forest) panels in that
    order. If ``plot_titles`` has one title per dataset (same length as
    ``usageFlags``), only the titles of enabled datasets are used, so titles stay
    aligned with their data. Otherwise ``plot_titles`` is taken to already list
    one title per drawn panel. The figure is saved to
    ``<out><output_filename>`` and then shown.

    Returns:
        The matplotlib Figure.
    """
    n_subplots = sum(usageFlags)
    fig, axes = plt.subplots(1, n_subplots, sharex=share_x, sharey=share_y, figsize=fig_size,
                             gridspec_kw={'width_ratios': [1] * n_subplots})
    fig.suptitle(sup_title, fontsize=sup_title_fontsize)
    if n_subplots <= 1:
        # plt.subplots returns a bare Axes (not an array) for a single panel.
        axes = [axes]

    if len(plot_titles) == len(usageFlags):
        titles = [title for title, used in zip(plot_titles, usageFlags) if used]
    else:
        titles = list(plot_titles)

    for i, ax in enumerate(axes):
        if i == 0:
            ax.set_ylabel(y_label, fontsize=fig_fontsize)
        ax.set_title(titles[i], fontsize=fig_fontsize)
        ax.tick_params(labelsize=label_size)
        if i > 0:
            # With share_y, matplotlib hides inner y tick labels; show them again.
            ax.tick_params(which='both', labelleft=True)

    datas = [[sq_data_bottom, sq_data_top], [oracle_data_bottom, oracle_data_top], [rf_data_bottom, rf_data_top]]
    used_datas = [pair for pair, used in zip(datas, usageFlags) if used]
    for ax, (bottom_data, top_data) in zip(axes, used_datas):
        ax.bar(" ", bottom_data, label=group_labels[0], color=group_colors[0], alpha=alphas[0])
        ax.bar(" ", top_data, bottom=bottom_data, label=group_labels[1], color=group_colors[1], alpha=alphas[1])

    # Legend entries come from the last panel, reversed to match stacking order.
    handles, leg_labels = axes[-1].get_legend_handles_labels()
    legend_axes = [axes[-1]] if right_label_only == True else list(axes)
    for ax in legend_axes:
        ax.legend(handles[::-1], leg_labels[::-1], fontsize=label_size, loc=label_loc)
    for ax in axes:
        ax.tick_params(bottom=False)

    # Save before show(): interactive/inline backends may close the figure in
    # show(), after which plt.savefig would write a blank image.
    fig.savefig(out + output_filename)
    plt.show()

    print('output saved successfully as ' + output_filename)

    return fig



        
def manyGroupBarPlotGeneral(
             share_x=False, 
             share_y=True, 
             fig_size=(14,8), 
             sup_title='Activity Code Shares for Oracle and Random Forest',
             sup_title_fontsize=24,
             y_label='Share of Audits',
             plot_titles=['Status Quo', 'Oracle', 'Random Forest'],
             fig_fontsize=20,
             label_size=16,
             cmap=mcolors.LinearSegmentedColormap.from_list("", ["accent blue", "background blue", "cardinal red"]),
             pct_idx=6,
             alpha=0.3,
             right_label_only=True,
             label_loc='upper right',
             sq_dict={},
             oracle_dict={},
             rf_dict={},
             output_filename='test.png',
             out=defaultout,
             usageFlags=[True,True,True],
             markEITC=True):
    """Stack activity-code shares (codes 270-281) in one bar per enabled dataset.

    Each dict maps str(code) to a sequence; element ``pct_idx`` is the share
    that is stacked. ``usageFlags`` enables the (status quo, oracle, random
    forest) panels in that order. Titles follow the same alignment rule as
    twoGroupBarPlot. With ``markEITC``, a thick black line is drawn at the
    bottom of the 272 segment, i.e. above codes 270-271 (presumably the EITC
    codes; confirm). ``alpha`` is accepted for compatibility but not used.
    The figure is saved to ``<out><output_filename>`` and then shown.

    Returns:
        The matplotlib Figure.
    """
    n_subplots = sum(usageFlags)
    # width_ratios must have one entry per panel (was hard-coded to three).
    fig, axes = plt.subplots(1, n_subplots, sharex=share_x, sharey=share_y, figsize=fig_size,
                             gridspec_kw={'width_ratios': [1] * n_subplots})
    fig.suptitle(sup_title, fontsize=sup_title_fontsize)
    if n_subplots <= 1:
        # plt.subplots returns a bare Axes (not an array) for a single panel.
        axes = [axes]

    if len(plot_titles) == len(usageFlags):
        titles = [title for title, used in zip(plot_titles, usageFlags) if used]
    else:
        titles = list(plot_titles)

    for i, ax in enumerate(axes):
        if i == 0:
            ax.set_ylabel(y_label, fontsize=fig_fontsize)
        ax.set_title(titles[i], fontsize=fig_fontsize)
        ax.tick_params(labelsize=label_size)
        if i > 0:
            # With share_y, matplotlib hides inner y tick labels; show them again.
            ax.tick_params(which='both', labelleft=True)

    codes = list(range(270, 282))
    used_dicts = [d for d, used in zip([sq_dict, oracle_dict, rf_dict], usageFlags) if used]
    for ax, shares in zip(axes, used_dicts):
        bottom = 0
        for ac in codes:
            height = shares[str(ac)][pct_idx]
            ax.bar(' ', height, bottom=bottom, label=str(ac), color=cmap((ac - 270) / len(codes)))
            if ac == 272 and markEITC:
                ax.axhline(y=bottom, color='black', linewidth=5)
            bottom += height

    # Legend entries come from the last panel, reversed to match stacking order.
    handles, leg_labels = axes[-1].get_legend_handles_labels()
    legend_axes = [axes[-1]] if right_label_only == True else list(axes)
    for ax in legend_axes:
        ax.legend(handles[::-1], leg_labels[::-1], fontsize=label_size, loc=label_loc)
    for ax in axes:
        ax.tick_params(bottom=False)

    # Save before show(): interactive/inline backends may close the figure in
    # show(), after which plt.savefig would write a blank image.
    fig.savefig(out + output_filename)
    plt.show()
    print('output saved successfully as ' + output_filename)
    return fig

        
def manyGroupBarPlot(
             share_x=False, 
             share_y=True, 
             fig_size=(14,8), 
             sup_title='Activity Code Shares for Oracle and Random Forest',
             sup_title_fontsize=24,
             y_label='Share of Audits',
             plot_titles=['Status Quo', 'Oracle', 'Random Forest'],
             fig_fontsize=20,
             label_size=16,
             cmap=mcolors.LinearSegmentedColormap.from_list("", ["accent blue", "background blue", "cardinal red"]),
             pct_idx=6,
             alpha=0.3,
             right_label_only=True,
             label_loc='upper right',
             sq_dict={},
             oracle_dict={},
             rf_dict={},
             output_filename='test.png',
             out=defaultout):
    """Stack activity-code shares (codes 270-281) for status quo, oracle and RF.

    Fixed three-panel variant of manyGroupBarPlotGeneral. Each dict maps
    str(code) to a sequence; element ``pct_idx`` is the share that is stacked.
    ``alpha`` is accepted for compatibility but not used. The figure is saved
    to ``<out><output_filename>`` and then shown.

    Returns:
        The matplotlib Figure.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=share_x, sharey=share_y, figsize=fig_size, gridspec_kw={'width_ratios': [1, 1, 1]})
    fig.suptitle(sup_title, fontsize=sup_title_fontsize)
    ax1.set_ylabel(y_label, fontsize=fig_fontsize)

    codes = list(range(270, 282))
    panels = [(ax1, sq_dict), (ax2, oracle_dict), (ax3, rf_dict)]
    for i, (ax, shares) in enumerate(panels):
        ax.set_title(plot_titles[i], fontsize=fig_fontsize)
        ax.tick_params(labelsize=label_size)
        if i > 0:
            # With share_y, matplotlib hides inner y tick labels; show them again.
            ax.tick_params(which='both', labelleft=True)
        bottom = 0
        for ac in codes:
            height = shares[str(ac)][pct_idx]
            ax.bar(' ', height, bottom=bottom, label=str(ac), color=cmap((ac - 270) / len(codes)))
            bottom += height

    # Legend entries come from the last panel, reversed to match stacking order.
    handles, leg_labels = ax3.get_legend_handles_labels()
    legend_axes = [ax3] if right_label_only == True else [ax1, ax2, ax3]
    for ax in legend_axes:
        ax.legend(handles[::-1], leg_labels[::-1], fontsize=label_size, loc=label_loc)

    for ax in (ax1, ax2, ax3):
        ax.tick_params(bottom=False)

    # Save before show(): interactive/inline backends may close the figure in
    # show(), after which plt.savefig would write a blank image.
    fig.savefig(out + output_filename)
    plt.show()

    print('output saved successfully as ' + output_filename)

    return fig

